{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    },
    "colab": {
      "name": "conv_batch_norm_folding.ipynb",
      "provenance": [],
      "collapsed_sections": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d_-CWko-hsua",
        "colab_type": "text"
      },
      "source": [
        "Copyright 2019 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
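    {
      "cell_type": "markdown",
      "metadata": {
        "id": "USX-5-Fthsuc",
        "colab_type": "text"
      },
      "source": [
        "# Example of folding a Keras conv layer with batch norm\n",
        "\n",
        "This notebook builds a small model with a `Conv2D` layer followed by `BatchNormalization`, folds the batch-norm parameters into the convolution's kernel and bias, and checks that the folded model produces the same outputs as the original one."
      ]
    },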
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bl5UlHSrhsud",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CeOAxwA9hsue",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "np.random.seed(123)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ggKIzRVqhsuk",
        "colab_type": "text"
      },
      "source": [
        "## Model with convolution and batch norm layer definition"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o73ItVSghsul",
        "colab_type": "code",
        "colab": {},
        "outputId": "335a3eb8-9e0a-4e95-e0fa-da02745aa1de"
      },
      "source": [
        "# Reference model: Conv2D -> BatchNormalization -> ReLU -> Flatten.\n",
        "# `epsilon` is kept in a variable because the folding step reuses it.\n",
        "epsilon = 0.001\n",
        "inputs = tf.keras.Input(shape=(50, 32, 5), batch_size=4)\n",
        "\n",
        "net = inputs\n",
        "for layer in (\n",
        "    tf.keras.layers.Conv2D(filters=2, kernel_size=(3, 3)),\n",
        "    tf.keras.layers.BatchNormalization(epsilon=epsilon),\n",
        "    tf.keras.layers.ReLU(),\n",
        "    tf.keras.layers.Flatten(),\n",
        "):\n",
        "  net = layer(net)\n",
        "\n",
        "model = tf.keras.Model(inputs, net)\n",
        "model.summary()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model: \"functional_27\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "input_18 (InputLayer)        [(4, 50, 32, 5)]          0         \n",
            "_________________________________________________________________\n",
            "conv2d_16 (Conv2D)           (4, 48, 30, 2)            92        \n",
            "_________________________________________________________________\n",
            "batch_normalization_13 (Batc (4, 48, 30, 2)            8         \n",
            "_________________________________________________________________\n",
            "re_lu_1 (ReLU)               (4, 48, 30, 2)            0         \n",
            "_________________________________________________________________\n",
            "flatten_14 (Flatten)         (4, 2880)                 0         \n",
            "=================================================================\n",
            "Total params: 100\n",
            "Trainable params: 96\n",
            "Non-trainable params: 4\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "syWcEi7ehsur",
        "colab_type": "code",
        "colab": {},
        "outputId": "ddb0b750-a2a9-402f-ab6e-57f84ff37724"
      },
      "source": [
        "# Show the layers in order: the conv and batch-norm layers are picked out of\n",
        "# this list by position further down.\n",
        "model.layers"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7f76b40d9cc0>,\n",
              " <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f76b40d9dd8>,\n",
              " <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f76b40db048>,\n",
              " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x7f76b40db7b8>,\n",
              " <tensorflow.python.keras.layers.core.Flatten at 0x7f76b40db470>]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 149
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b28RX7EMhsuw",
        "colab_type": "text"
      },
      "source": [
        "## Initialize all model weights with random numbers"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Zt_Pm7zhhsux",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Overwrite every weight array of the model with random numbers so that even\n",
        "# the biases are non-zero; this helps to validate the numerical correctness of\n",
        "# the conv and batch norm fusion.\n",
        "all_weights = [np.random.random(w.shape) for w in model.get_weights()]\n",
        "model.set_weights(all_weights)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T4RR29BDhsu2",
        "colab_type": "text"
      },
      "source": [
        "## Dims of conv layer weights"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SSNRxArthsu3",
        "colab_type": "code",
        "colab": {},
        "outputId": "0e96fde3-d24b-4acb-c808-dd8f2d758450"
      },
      "source": [
        "ind_conv_layer = 1\n",
        "conv_layer = model.layers[ind_conv_layer]\n",
        "assert isinstance(conv_layer, tf.keras.layers.Conv2D)\n",
        "\n",
        "# Index 0 is treated as the kernel and index 1 as the bias.\n",
        "conv_weights = conv_layer.get_weights()\n",
        "print(f\"conv weights shape {conv_weights[0].shape}\")\n",
        "print(f\"conv bias shape {conv_weights[1].shape}\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "conv weights shape (3, 3, 5, 2)\n",
            "conv bias shape (2,)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "id3mK_cvhsu7",
        "colab_type": "text"
      },
      "source": [
        "## Dims of batch norm layer weights"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5ModR3tEhsu7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "ind_batch_norm_layer = 2\n",
        "bn_layer = model.layers[ind_batch_norm_layer]\n",
        "assert isinstance(bn_layer, tf.keras.layers.BatchNormalization)\n",
        "\n",
        "# Raw batch-norm arrays; they are unpacked by position in the next cell.\n",
        "bn_weights = bn_layer.get_weights()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GkymeV5QhsvB",
        "colab_type": "code",
        "colab": {},
        "outputId": "47e23866-7945-4740-e9db-a9e7c25b6083"
      },
      "source": [
        "# With the default center=True and scale=True, BatchNormalization.get_weights()\n",
        "# returns [gamma, beta, moving_mean, moving_variance], in that order.\n",
        "gamma, betta, mean, variance = bn_weights\n",
        "\n",
        "# Print each array's own shape (previously every line reported gamma.shape).\n",
        "print(\"gamma shape \" + str(gamma.shape))\n",
        "print(\"betta shape \" + str(betta.shape))\n",
        "print(\"mean shape \" + str(mean.shape))\n",
        "print(\"variance shape \" + str(variance.shape))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "gamma shape (2,)\n",
            "betta shape (2,)\n",
            "mean shape (2,)\n",
            "variance shape (2,)\n"
          ],
          "name": "stdout"
        }
      ]
    },
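    {
      "cell_type": "markdown",
      "metadata": {
        "id": "26G2EsMUhsvG",
        "colab_type": "text"
      },
      "source": [
        "## Fuse conv and batch norm weights\n",
        "\n",
        "At inference time batch norm computes $y = \\gamma \\, (x - \\mu) / \\sqrt{\\sigma^2 + \\epsilon} + \\beta$ per output channel, where $x$ is the conv output (kernel $W$ applied to the input, plus bias $b$) and $\\mu$, $\\sigma^2$ are the moving mean and variance. Substituting the conv output for $x$ gives an equivalent convolution with\n",
        "\n",
        "$$W' = \\frac{\\gamma \\, W}{\\sqrt{\\sigma^2 + \\epsilon}}, \\qquad b' = \\beta + \\frac{\\gamma \\, (b - \\mu)}{\\sqrt{\\sigma^2 + \\epsilon}}$$"
      ]
    },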
    {
      "cell_type": "code",
      "metadata": {
        "id": "a9ZHY3OChsvG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Denominator shared by both folded terms.\n",
        "bn_std = np.sqrt(variance + epsilon)\n",
        "\n",
        "kernel, bias = conv_weights\n",
        "# gamma and bn_std hold one value per output channel (the kernel's last axis),\n",
        "# so NumPy broadcasting scales each output filter independently.\n",
        "new_conv_weights = (kernel * gamma) / bn_std\n",
        "new_bias = betta + ((bias - mean) * gamma) / bn_std"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mYH__wwJhsvK",
        "colab_type": "text"
      },
      "source": [
        "## Model with folded/fused convolution and batch norm layers"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KS-lVY3fhsvK",
        "colab_type": "code",
        "colab": {},
        "outputId": "0e95b17b-cd59-4821-f0d6-ffed83fede57"
      },
      "source": [
        "# Folding removes only the batch-norm layer. The ReLU that follows it in `model`\n",
        "# has to stay, otherwise the two models differ whenever a pre-activation value\n",
        "# is negative (the all-positive random weights and inputs used here hide that).\n",
        "inputs = tf.keras.Input(shape=(50, 32, 5), batch_size=4)\n",
        "net = inputs\n",
        "net = tf.keras.layers.Conv2D(filters=2, kernel_size=(3,3))(net)\n",
        "net = tf.keras.layers.ReLU()(net)\n",
        "net = tf.keras.layers.Flatten()(net)\n",
        "model_fused = tf.keras.Model(inputs, net)\n",
        "model_fused.summary()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CvRTkWqWhsvO",
        "colab_type": "text"
      },
      "source": [
        "## Initialize model_fused with fused weights"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "O8DY7zaKhsvO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Slot 0 receives the folded kernel and slot 1 the folded bias; any other\n",
        "# entries returned by get_weights() are written back unchanged.\n",
        "all_weights_fused = model_fused.get_weights()\n",
        "all_weights_fused[0], all_weights_fused[1] = new_conv_weights, new_bias\n",
        "model_fused.set_weights(all_weights_fused)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sZjebQVxhsvT",
        "colab_type": "text"
      },
      "source": [
        "## Validate that model and model_fused produce the same outputs"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2p-xu6LFhsvU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# One random batch with the full input shape (batch dimension included).\n",
        "input_data = np.random.random(size=inputs.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6vlcJ73mhsvX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Reference result from the original conv + batch norm model.\n",
        "outputs = model.predict(x=input_data)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "szVMnesmhsvb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Result from the folded model on the same input batch.\n",
        "outputs_fused = model_fused.predict(x=input_data)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LUMC5mTkhsve",
        "colab_type": "code",
        "colab": {},
        "outputId": "5808d6a9-191f-46ea-c68c-99b6cd13cd9d"
      },
      "source": [
        "# Element-wise comparison using NumPy's default tolerances, spelled out here.\n",
        "np.allclose(outputs, outputs_fused, rtol=1e-05, atol=1e-08)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 146
        }
      ]
    }
  ]
}
